use crate::matrix::{apply_matrix, vec_add};
use crate::mds::MDSMatrices;
use crate::quintic_s_box;
use ff::{Field, ScalarEngine};

/// Compresses the per-round constants by pushing them back through the linear (MDS) layers and
/// through the identity components of the partial layers, so that constants need only be added
/// after each S-box.
///
/// `round_constants` is read as consecutive rows of `width` elements, row `r` holding the
/// constants of round `r`. With `half = full_rounds / 2`, `last = half + partial_rounds` and
/// `unpreprocessed = partial_rounds - partial_preprocessed`, the returned vector is the
/// concatenation of:
///
/// 1. Row `0`, unchanged (it is added before any S-box).
/// 2. `M^-1` applied to rows `1..=half` (first set of full rounds). When every partial round is
///    preprocessed (`unpreprocessed == 0`) row `half` is left out here, because its contribution
///    is obtained from the result of preprocessing the partial rounds (item 4).
/// 3. Rows `half + 1 .. half + unpreprocessed`, unchanged.
/// 4. `M^-1` applied to the row accumulated while folding the last `partial_preprocessed` partial
///    rounds backwards, starting from row `last`.
/// 5. One constant per preprocessed partial round, ordered so that the last one produced by the
///    backwards fold comes first.
/// 6. `M^-1` applied to rows `last + 1 .. last + half` (second set of full rounds).
///
/// # Parameters
///
/// * `width` - number of constants per round (row length).
/// * `full_rounds` - total number of full rounds; only `full_rounds / 2` is used.
/// * `partial_rounds` - number of partial rounds.
/// * `round_constants` - the uncompressed constants, row-major.
/// * `mds_matrices` - supplies the MDS matrix `m` and its inverse `m_inv`.
/// * `partial_preprocessed` - how many of the trailing partial rounds to preprocess.
///
/// # Panics
///
/// * If `partial_preprocessed > partial_rounds`.
/// * If all partial rounds are preprocessed and `full_rounds < 2`.
/// * If `round_constants` is too short for any row that is read.
/// * If `partial_preprocessed == 1` and the built-in single-step self-check fails.
pub(crate) fn compress_round_constants<E: ScalarEngine>(
    width: usize,
    full_rounds: usize,
    partial_rounds: usize,
    round_constants: &Vec<E::Fr>,
    mds_matrices: &MDSMatrices<E>,
    partial_preprocessed: usize,
) -> Vec<E::Fr> {
    let inverse_matrix = &mds_matrices.m_inv;

    // Row `r` of the constants, i.e. the `width` constants belonging to round `r`.
    let round_keys = |r: usize| &round_constants[r * width..(r + 1) * width];

    let half_full_rounds = full_rounds / 2; // Not half-full rounds; half full-rounds.

    // Fail with a clear message instead of an arithmetic overflow (debug builds) or a wrapped
    // count that only surfaces later as an out-of-range slice index (release builds).
    let unpreprocessed = partial_rounds
        .checked_sub(partial_preprocessed)
        .expect("partial_preprocessed must not exceed partial_rounds");

    let mut res = Vec::new();

    // First round constants are unchanged.
    res.extend(round_keys(0));

    // Post S-box adds for the first set of full rounds should be 'inverted' from next round.
    // The final round is skipped when fully preprocessing because that value must be obtained
    // from the result of preprocessing the partial rounds.
    let end = if unpreprocessed > 0 {
        half_full_rounds
    } else {
        half_full_rounds
            .checked_sub(1)
            .expect("fully preprocessing the partial rounds requires at least two full rounds")
    };
    for i in 0..end {
        // `i + 1` because the first round was added before any S-boxes.
        res.extend(apply_matrix::<E>(inverse_matrix, round_keys(i + 1)));
    }

    // The plan:
    // - Work backwards from last row in this group.
    // - Invert the row.
    // - Save first constant (corresponding to the one S-box performed).
    // - Add inverted result (with that first constant zeroed) to previous row.
    // - Repeat until all preprocessed partial round key rows have been consumed.
    // - Extend the preprocessed result by the (inverted) final resultant row.
    // - Move the accumulated list of single round keys to the preprocessed result.
    //   (Last produced should be first applied, so they are appended in reverse.)

    // Index of the last row taking part in the partial-round compression.
    let final_round = half_full_rounds + partial_rounds;

    // `partial_keys` accumulates the single post-S-box constant for each preprocessed partial
    // round, in reverse order.
    let mut partial_keys: Vec<E::Fr> = Vec::with_capacity(partial_preprocessed);

    // `round_acc` holds the accumulated result of inverting and adding subsequent round constants
    // (in reverse).
    let round_acc = (0..partial_preprocessed)
        .map(|i| round_keys(final_round - i - 1))
        .fold(
            round_keys(final_round).to_vec(),
            |acc, previous_round_keys| {
                let mut inverted = apply_matrix::<E>(inverse_matrix, &acc);

                partial_keys.push(inverted[0]);
                inverted[0] = E::Fr::zero();

                vec_add::<E>(previous_round_keys, &inverted)
            },
        );

    // Dev-driven testing; the dev test case only checks one deep.
    if partial_preprocessed == 1 {
        verify_single_preprocessed_round::<E>(
            width,
            round_keys(final_round - 1),
            round_keys(final_round),
            &round_acc,
            &partial_keys,
            mds_matrices,
        );
    }

    // Constants of the partial rounds that are not preprocessed are passed through unchanged.
    for i in 1..unpreprocessed {
        res.extend(round_keys(half_full_rounds + i));
    }
    res.extend(apply_matrix::<E>(inverse_matrix, &round_acc));

    // Last produced, first applied.
    res.extend(partial_keys.into_iter().rev());

    // Post S-box adds for the second set of full rounds should be 'inverted' from next round.
    for i in 1..half_full_rounds {
        res.extend(apply_matrix::<E>(
            inverse_matrix,
            round_keys(final_round + i),
        ));
    }

    res
}

/// Dev-driven self-check for the `partial_preprocessed == 1` case: confirms how the backwards fold
/// manifested and that one partial round gives the same state with the preprocessed constants as
/// with the original ones.
///
/// * `initial_round_keys` (I) - constants of the first of the two rounds being compressed.
/// * `terminal_round_keys` (T) - constants of the second of the two rounds being compressed.
/// * `round_acc` / `partial_keys` - the values produced by the fold in the caller.
///
/// Panics (via `assert_eq!`) if any assumption does not hold.
fn verify_single_preprocessed_round<E: ScalarEngine>(
    width: usize,
    initial_round_keys: &[E::Fr],
    terminal_round_keys: &[E::Fr],
    round_acc: &[E::Fr],
    partial_keys: &[E::Fr],
    mds_matrices: &MDSMatrices<E>,
) {
    let mds_matrix = &mds_matrices.m;
    let inverse_matrix = &mds_matrices.m_inv;

    // Check assumptions about how the fold calculating `round_acc` manifested.

    // M^-1(T)
    let mut inv = apply_matrix::<E>(inverse_matrix, terminal_round_keys);

    // M^-1(T)[0]
    let pk = inv[0];

    // M^-1(T) with its first component zeroed.
    inv[0] = E::Fr::zero();

    // I + (M^-1(T) with its first component zeroed)
    let result_key = vec_add::<E>(initial_round_keys, &inv);

    assert_eq!(result_key.as_slice(), round_acc, "Acc assumption failed.");
    assert_eq!(pk, partial_keys[0], "Partial-key assumption failed.");
    assert_eq!(
        1,
        partial_keys.len(),
        "Partial-keys length assumption failed."
    );

    // Shared between branches: arbitrary initial state representing the output of a previous
    // round's S-box layer.
    // X
    let initial_state = vec![E::Fr::one(); width];

    // Compute one step with the given (unpreprocessed) constants.

    // ARK
    // I + X
    let mut q_state = vec_add::<E>(initial_round_keys, &initial_state);

    // S-box (partial layer)
    // S((I + X)[0]) = S(I[0] + X[0])
    quintic_s_box::<E>(&mut q_state[0], None, None);

    // Mix with mds_matrix
    let mixed = apply_matrix::<E>(mds_matrix, &q_state);

    // ARK
    let plain_result = vec_add::<E>(terminal_round_keys, &mixed);

    // Compute the same step using the preprocessed constants.
    let mut p_state = vec_add::<E>(&result_key, &initial_state);

    // In order for the S-box result to be correct, it must have the same input as in the plain
    // path. That means its input (the first component of the state) must have been constructed by
    // adding the same single round constant in that position.
    // NOTE: this assertion uncovered a bug which was causing failure.
    assert_eq!(
        &result_key[0], &initial_round_keys[0],
        "S-box inputs did not match."
    );

    // `pk` is handed to the S-box as its third argument (presumably the post-S-box addend —
    // confirm against `quintic_s_box`), replacing the separate add of T after mixing.
    quintic_s_box::<E>(&mut p_state[0], None, Some(&pk));

    let preprocessed_result = apply_matrix::<E>(mds_matrix, &p_state);

    assert_eq!(
        plain_result, preprocessed_result,
        "Single preprocessing step couldn't be verified."
    );
}
